use rama_core::context::Extensions;
use std::{fmt, io};

use crate::tls::{
    CipherSuite, ECPointFormat, ExtensionId, ProtocolVersion, SecureTransport, SupportedGroup,
    client::NegotiatedTlsParameters,
};

use super::ClientHelloProvider;

/// A JA3 TLS client fingerprint, computed from a client hello.
///
/// The [`Display`](fmt::Display) output is the raw JA3 string
/// (`version,ciphers,extensions,groups,point_formats`, each list joined by `-`),
/// while [`LowerHex`](fmt::LowerHex) and [`UpperHex`](fmt::UpperHex) output the
/// MD5 digest of that string, i.e. the "ja3" hash, which is also available as a
/// `String` via [`Self::hash`].
///
/// Computed using [`Ja3::compute`] or [`Ja3::compute_from_client_hello`].
#[derive(Debug, Clone)]
pub struct Ja3 {
    /// TLS protocol version reported as the first JA3 field.
    version: ProtocolVersion,
    /// Offered cipher suites with GREASE values removed; never empty once computed.
    cipher_suites: Vec<CipherSuite>,
    /// Non-GREASE extension ids in offer order; `None` when there are none.
    extensions: Option<Vec<ExtensionId>>,
    /// Non-GREASE supported groups; `None` when absent or empty.
    supported_groups: Option<Vec<SupportedGroup>>,
    /// EC point formats; `None` when absent or empty.
    ec_point_formats: Option<Vec<ECPointFormat>>,
}

impl Ja3 {
    /// Compute the [`Ja3`] fingerprint from the TLS data stored in `ext`.
    ///
    /// The client hello is taken from the [`SecureTransport`] extension and the
    /// protocol version from [`NegotiatedTlsParameters`] when present; see
    /// [`Self::compute_from_client_hello`] for how both are used.
    ///
    /// As specified by <https://github.com/salesforce/ja3>.
    ///
    /// # Errors
    ///
    /// - [`Ja3ComputeError::MissingClientHello`] if `ext` holds no [`SecureTransport`]
    ///   or that transport carries no client hello;
    /// - [`Ja3ComputeError::EmptyCipherSuites`] if the client hello offers no
    ///   cipher suite besides GREASE values.
    pub fn compute(ext: &Extensions) -> Result<Self, Ja3ComputeError> {
        let client_hello = ext
            .get::<SecureTransport>()
            .and_then(|st| st.client_hello())
            .ok_or(Ja3ComputeError::MissingClientHello)?;
        let negotiated_tls_version = ext
            .get::<NegotiatedTlsParameters>()
            .map(|param| param.protocol_version);
        Self::compute_from_client_hello(client_hello, negotiated_tls_version)
    }

    /// Compute the [`Ja3`] (hash) from a reference to either a
    /// [`ClientHello`] or a [`ClientConfig`] data structure.
    ///
    /// When `negotiated_tls_version` is `Some`, it becomes the JA3 version field.
    /// Otherwise the client hello's own protocol version is used.
    /// NOTE(review): the reference JA3 implementation uses the client hello
    /// version field; preferring the negotiated version can change the
    /// fingerprint (e.g. for TLS 1.3) — confirm this is intended.
    ///
    /// GREASE values are skipped in cipher suites, extensions and supported
    /// groups. Supported groups and EC point formats are only recorded when the
    /// corresponding extension still holds at least one entry.
    ///
    /// In case your source is [`Extensions`] you can use [`Self::compute`] instead.
    ///
    /// # Errors
    ///
    /// Returns [`Ja3ComputeError::EmptyCipherSuites`] if no non-GREASE cipher
    /// suite remains.
    ///
    /// [`ClientHello`]: crate::tls::client::ClientHello
    /// [`ClientConfig`]: crate::tls::client::ClientConfig
    pub fn compute_from_client_hello(
        client_hello: impl ClientHelloProvider,
        negotiated_tls_version: Option<ProtocolVersion>,
    ) -> Result<Self, Ja3ComputeError> {
        use crate::tls::client::ClientHelloExtension;

        let version = negotiated_tls_version.unwrap_or_else(|| {
            tracing::trace!(
                "negotiated tls protocol version missing: fallback to client hello tls"
            );
            client_hello.protocol_version()
        });

        let cipher_suites: Vec<_> = client_hello
            .cipher_suites()
            .filter(|c| !c.is_grease())
            .collect();
        if cipher_suites.is_empty() {
            return Err(Ja3ComputeError::EmptyCipherSuites);
        }

        let mut extensions = None;
        let mut supported_groups = None;
        let mut ec_point_formats = None;

        for extension in client_hello.extensions() {
            if extension.id().is_grease() {
                continue;
            }

            // Every non-GREASE extension id is recorded, in offer order.
            extensions.get_or_insert_with(Vec::new).push(extension.id());

            match extension {
                ClientHelloExtension::SupportedGroups(groups) => {
                    let groups: Vec<_> =
                        groups.iter().filter(|g| !g.is_grease()).copied().collect();
                    if !groups.is_empty() {
                        supported_groups = Some(groups);
                    }
                }
                ClientHelloExtension::ECPointFormats(formats) if !formats.is_empty() => {
                    ec_point_formats = Some(formats.clone());
                }
                _ => (),
            }
        }

        Ok(Self {
            version,
            cipher_suites,
            extensions,
            supported_groups,
            ec_point_formats,
        })
    }

    /// Compute the "ja3" hash of this [`Ja3`] as a `String`: the lowercase hex
    /// MD5 digest of its JA3 string.
    #[inline]
    pub fn hash(&self) -> String {
        format!("{self:x}")
    }

    /// Write the MD5 digest of this [`Ja3`]'s JA3 string into `w`, as lowercase
    /// hex when `lower` is `true` and uppercase hex otherwise.
    fn hash_to(&self, w: &mut impl fmt::Write, lower: bool) -> fmt::Result {
        let mut ctx = md5::Context::new();
        // Feeding an in-memory md5 context is not expected to fail: treat a
        // failure as a bug in debug builds, and keep hashing in release builds.
        let _ = self.write_to_io(&mut ctx).inspect_err(|err| {
            if cfg!(debug_assertions) {
                panic!("md5 ingest failed: {err:?}");
            }
        });
        let digest = ctx.compute();
        if lower {
            write!(w, "{digest:x}")
        } else {
            write!(w, "{digest:X}")
        }
    }
}

// Writes the raw JA3 string of `$this` (a `Ja3`) into `$w`.
//
// This is a macro rather than a generic fn because the same logic must drive
// both an `io::Write` sink (md5 ingestion) and a `fmt::Write` sink (Display),
// and those two traits share no common abstraction.
//
// Output: `version,ciphers,extensions,groups,point_formats`, each list joined
// by `-`. All five fields are always present: an absent (`None`) or empty list
// still emits its leading `,`, producing an empty field instead of silently
// dropping a separator.
macro_rules! impl_write_to {
    // Internal rule: write the field separator `,` followed by the `-`-joined items.
    (@list $w:ident, $items:expr) => {{
        write!($w, ",")?;
        for (idx, item) in $items.enumerate() {
            if idx > 0 {
                write!($w, "-")?;
            }
            write!($w, "{item}")?;
        }
    }};
    ($w:ident, $this:ident) => {{
        write!($w, "{}", u16::from($this.version))?;
        impl_write_to!(@list $w, $this.cipher_suites.iter().map(|c| u16::from(*c)));
        impl_write_to!(@list $w, $this.extensions.iter().flatten().map(|e| u16::from(*e)));
        impl_write_to!(
            @list $w,
            $this.supported_groups.iter().flatten().map(|g| u16::from(*g))
        );
        impl_write_to!(
            @list $w,
            $this.ec_point_formats.iter().flatten().map(|p| u8::from(*p))
        );
        Ok(())
    }};
}

impl Ja3 {
    /// Stream the raw JA3 string into a text sink (this is what `Display` uses).
    fn write_to_fmt(&self, sink: &mut impl fmt::Write) -> fmt::Result {
        impl_write_to!(sink, self)
    }

    /// Stream the raw JA3 string into a byte sink (this is how the md5 context
    /// is fed when hashing).
    fn write_to_io(&self, sink: &mut impl io::Write) -> io::Result<()> {
        impl_write_to!(sink, self)
    }
}

impl fmt::Display for Ja3 {
    /// Render the raw (unhashed) JA3 string.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        Self::write_to_fmt(self, f)
    }
}

impl fmt::LowerHex for Ja3 {
    /// Render the "ja3" hash (MD5 of the JA3 string) as lowercase hex.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        Self::hash_to(self, f, true)
    }
}

impl fmt::UpperHex for Ja3 {
    /// Render the "ja3" hash (MD5 of the JA3 string) as uppercase hex.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        Self::hash_to(self, f, false)
    }
}

/// Error returned when a [`Ja3`] fingerprint cannot be computed, by either
/// [`Ja3::compute`] or [`Ja3::compute_from_client_hello`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Ja3ComputeError {
    /// No [`ClientHello`](crate::tls::client::ClientHello) was available.
    MissingClientHello,
    /// The client hello offered no cipher suites once GREASE values were removed.
    EmptyCipherSuites,
}

impl fmt::Display for Ja3ComputeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let reason = match self {
            Self::MissingClientHello => "missing client hello",
            Self::EmptyCipherSuites => "empty cipher suites",
        };
        write!(f, "Ja3 Compute Error: {reason}")
    }
}

impl std::error::Error for Ja3ComputeError {}

#[cfg(test)]
mod tests {
    use crate::tls::client::parse_client_hello;

    use super::*;

    /// A captured client hello together with its expected JA3 outputs.
    #[derive(Debug)]
    struct TestCase {
        /// Raw client hello bytes as accepted by `parse_client_hello`.
        client_hello: Vec<u8>,
        /// Name of the capture the bytes come from, used in failure messages.
        pcap: &'static str,
        /// Expected raw JA3 string (`Display` output).
        expected_ja3_str: &'static str,
        /// Expected lowercase JA3 hash (`LowerHex` output).
        expected_ja3_hash: &'static str,
    }

    #[test]
    fn test_ja3_compute() {
        // src: <https://github.com/jabedude/ja3-rs/blob/a30d1bea03d2230b1239d437c3f6af7fb7699338/src/lib.rs#L380>
        let test_cases = [
            TestCase {
                client_hello: vec![
                    0x3, 0x3, 0x86, 0xad, 0xa4, 0xcc, 0x19, 0xe7, 0x14, 0x54, 0x54, 0xfd, 0xe7,
                    0x37, 0x33, 0xdf, 0x66, 0xcb, 0xf6, 0xef, 0x3e, 0xc0, 0xa1, 0x54, 0xc6, 0xdd,
                    0x14, 0x5e, 0xc0, 0x83, 0xac, 0xb9, 0xb4, 0xe7, 0x20, 0x1c, 0x64, 0xae, 0xa7,
                    0xa2, 0xc3, 0xe1, 0x8c, 0xd1, 0x25, 0x2, 0x4d, 0xf7, 0x86, 0x4a, 0xc7, 0x19,
                    0xd0, 0xc4, 0xbd, 0xfb, 0x40, 0xc2, 0xef, 0x7f, 0x6d, 0xd3, 0x9a, 0xa7, 0x53,
                    0xdf, 0xdd, 0x0, 0x22, 0x1a, 0x1a, 0x13, 0x1, 0x13, 0x2, 0x13, 0x3, 0xc0, 0x2b,
                    0xc0, 0x2f, 0xc0, 0x2c, 0xc0, 0x30, 0xcc, 0xa9, 0xcc, 0xa8, 0xc0, 0x13, 0xc0,
                    0x14, 0x0, 0x9c, 0x0, 0x9d, 0x0, 0x2f, 0x0, 0x35, 0x0, 0xa, 0x1, 0x0, 0x1,
                    0x91, 0xa, 0xa, 0x0, 0x0, 0x0, 0x0, 0x0, 0x20, 0x0, 0x1e, 0x0, 0x0, 0x1b, 0x67,
                    0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x61, 0x64, 0x73, 0x2e, 0x67, 0x2e, 0x64, 0x6f,
                    0x75, 0x62, 0x6c, 0x65, 0x63, 0x6c, 0x69, 0x63, 0x6b, 0x2e, 0x6e, 0x65, 0x74,
                    0x0, 0x17, 0x0, 0x0, 0xff, 0x1, 0x0, 0x1, 0x0, 0x0, 0xa, 0x0, 0xa, 0x0, 0x8,
                    0x9a, 0x9a, 0x0, 0x1d, 0x0, 0x17, 0x0, 0x18, 0x0, 0xb, 0x0, 0x2, 0x1, 0x0, 0x0,
                    0x23, 0x0, 0x0, 0x0, 0x10, 0x0, 0xe, 0x0, 0xc, 0x2, 0x68, 0x32, 0x8, 0x68,
                    0x74, 0x74, 0x70, 0x2f, 0x31, 0x2e, 0x31, 0x0, 0x5, 0x0, 0x5, 0x1, 0x0, 0x0,
                    0x0, 0x0, 0x0, 0xd, 0x0, 0x14, 0x0, 0x12, 0x4, 0x3, 0x8, 0x4, 0x4, 0x1, 0x5,
                    0x3, 0x8, 0x5, 0x5, 0x1, 0x8, 0x6, 0x6, 0x1, 0x2, 0x1, 0x0, 0x12, 0x0, 0x0,
                    0x0, 0x33, 0x0, 0x2b, 0x0, 0x29, 0x9a, 0x9a, 0x0, 0x1, 0x0, 0x0, 0x1d, 0x0,
                    0x20, 0x59, 0x8, 0x6f, 0x41, 0x9a, 0xa5, 0xaa, 0x1d, 0x81, 0xe3, 0x47, 0xf0,
                    0x25, 0x5f, 0x92, 0x7, 0xfc, 0x4b, 0x13, 0x74, 0x51, 0x46, 0x98, 0x8, 0x74,
                    0x3b, 0xde, 0x57, 0x86, 0xe8, 0x2c, 0x74, 0x0, 0x2d, 0x0, 0x2, 0x1, 0x1, 0x0,
                    0x2b, 0x0, 0xb, 0xa, 0xfa, 0xfa, 0x3, 0x4, 0x3, 0x3, 0x3, 0x2, 0x3, 0x1, 0x0,
                    0x1b, 0x0, 0x3, 0x2, 0x0, 0x2, 0xba, 0xba, 0x0, 0x1, 0x0, 0x0, 0x15, 0x0, 0xbd,
                    0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
                    0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
                    0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
                    0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
                    0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
                    0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
                    0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
                    0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
                    0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
                    0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
                    0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
                    0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
                ],
                pcap: "chrome-grease-single.pcap",
                expected_ja3_str: "771,4865-4866-4867-49195-49199-49196-49200-52393-52392-49171-49172-156-157-47-53-10,0-23-65281-10-11-35-16-5-13-18-51-45-43-27-21,29-23-24,0",
                expected_ja3_hash: "66918128f1b9b03303d77c6f2eefd128",
            },
            TestCase {
                client_hello: vec![
                    0x3, 0x3, 0xf6, 0x65, 0xb, 0x22, 0x13, 0xf1, 0xc3, 0xe9, 0xe7, 0xb3, 0xdc, 0x9,
                    0xe4, 0x4b, 0xcb, 0x6e, 0x5, 0xaf, 0x8f, 0x2f, 0x41, 0x8d, 0x15, 0xa8, 0x88,
                    0x46, 0x24, 0x83, 0xca, 0x9, 0x7c, 0x95, 0x20, 0x12, 0xc4, 0x5e, 0x71, 0x8b,
                    0xb9, 0xc9, 0xa9, 0x37, 0x93, 0x4c, 0x41, 0xa6, 0xe8, 0x9e, 0x8f, 0x15, 0x78,
                    0x52, 0xe, 0x3c, 0x28, 0xba, 0xab, 0xa3, 0x34, 0x8b, 0x53, 0x82, 0x83, 0x75,
                    0x24, 0x0, 0x3e, 0x13, 0x2, 0x13, 0x3, 0x13, 0x1, 0xc0, 0x2c, 0xc0, 0x30, 0x0,
                    0x9f, 0xcc, 0xa9, 0xcc, 0xa8, 0xcc, 0xaa, 0xc0, 0x2b, 0xc0, 0x2f, 0x0, 0x9e,
                    0xc0, 0x24, 0xc0, 0x28, 0x0, 0x6b, 0xc0, 0x23, 0xc0, 0x27, 0x0, 0x67, 0xc0,
                    0xa, 0xc0, 0x14, 0x0, 0x39, 0xc0, 0x9, 0xc0, 0x13, 0x0, 0x33, 0x0, 0x9d, 0x0,
                    0x9c, 0x0, 0x3d, 0x0, 0x3c, 0x0, 0x35, 0x0, 0x2f, 0x0, 0xff, 0x1, 0x0, 0x1,
                    0x75, 0x0, 0x0, 0x0, 0x10, 0x0, 0xe, 0x0, 0x0, 0xb, 0x65, 0x78, 0x61, 0x6d,
                    0x70, 0x6c, 0x65, 0x2e, 0x63, 0x6f, 0x6d, 0x0, 0xb, 0x0, 0x4, 0x3, 0x0, 0x1,
                    0x2, 0x0, 0xa, 0x0, 0xc, 0x0, 0xa, 0x0, 0x1d, 0x0, 0x17, 0x0, 0x1e, 0x0, 0x19,
                    0x0, 0x18, 0x33, 0x74, 0x0, 0x0, 0x0, 0x10, 0x0, 0xe, 0x0, 0xc, 0x2, 0x68,
                    0x32, 0x8, 0x68, 0x74, 0x74, 0x70, 0x2f, 0x31, 0x2e, 0x31, 0x0, 0x16, 0x0, 0x0,
                    0x0, 0x17, 0x0, 0x0, 0x0, 0xd, 0x0, 0x30, 0x0, 0x2e, 0x4, 0x3, 0x5, 0x3, 0x6,
                    0x3, 0x8, 0x7, 0x8, 0x8, 0x8, 0x9, 0x8, 0xa, 0x8, 0xb, 0x8, 0x4, 0x8, 0x5, 0x8,
                    0x6, 0x4, 0x1, 0x5, 0x1, 0x6, 0x1, 0x3, 0x3, 0x2, 0x3, 0x3, 0x1, 0x2, 0x1, 0x3,
                    0x2, 0x2, 0x2, 0x4, 0x2, 0x5, 0x2, 0x6, 0x2, 0x0, 0x2b, 0x0, 0x9, 0x8, 0x3,
                    0x4, 0x3, 0x3, 0x3, 0x2, 0x3, 0x1, 0x0, 0x2d, 0x0, 0x2, 0x1, 0x1, 0x0, 0x33,
                    0x0, 0x26, 0x0, 0x24, 0x0, 0x1d, 0x0, 0x20, 0x37, 0x98, 0x48, 0x7f, 0x2f, 0xbc,
                    0x86, 0xf9, 0xb8, 0x2, 0xcd, 0x31, 0xf0, 0x4, 0x30, 0xa9, 0x2f, 0x29, 0x61,
                    0xac, 0xec, 0xc9, 0x2f, 0xf7, 0x45, 0xad, 0xd9, 0x67, 0x7, 0x14, 0x62, 0x1,
                    0x0, 0x15, 0x0, 0xb6, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
                    0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
                    0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
                    0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
                    0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
                    0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
                    0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
                    0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
                    0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
                    0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
                    0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
                    0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
                ],
                pcap: "curl.pcap",
                expected_ja3_str: "771,4866-4867-4865-49196-49200-159-52393-52392-52394-49195-49199-158-49188-49192-107-49187-49191-103-49162-49172-57-49161-49171-51-157-156-61-60-53-47-255,0-11-10-13172-16-22-23-13-43-45-51-21,29-23-30-25-24,0-1-2",
                expected_ja3_hash: "456523fc94726331a4d5a2e1d40b2cd7",
            },
        ];
        for test_case in test_cases {
            // Only the client hello is provided (no negotiated parameters), so the
            // JA3 version falls back to the client hello's protocol version.
            let mut ext = Extensions::new();
            ext.insert(SecureTransport::with_client_hello(
                parse_client_hello(&test_case.client_hello).expect(test_case.pcap),
            ));

            let ja3 = Ja3::compute(&ext).expect(test_case.pcap);

            // Raw JA3 string via `Display`.
            assert_eq!(
                test_case.expected_ja3_str,
                format!("{ja3}"),
                "pcap: {}",
                test_case.pcap,
            );

            // JA3 hash via `LowerHex`.
            assert_eq!(
                test_case.expected_ja3_hash,
                format!("{ja3:x}"),
                "pcap: {}",
                test_case.pcap,
            );
        }
    }

    #[test]
    fn test_ja3_compute_missing_client_hello() {
        // Without a `SecureTransport` there is no client hello to fingerprint.
        let err = Ja3::compute(&Extensions::new()).unwrap_err();
        assert!(
            matches!(err, Ja3ComputeError::MissingClientHello),
            "unexpected error: {err}"
        );
    }
}
